import contextlib
import os
import random
from pathlib import Path

import numpy as np
import torch
import torchvision.transforms as T
from PIL import Image, ImageDraw, ImageFont
from diffusers import StableDiffusionControlNetPipeline, StableDiffusionPipeline, \
    StableDiffusionDepth2ImgPipeline, ControlNetModel
from torchvision.transforms import InterpolationMode

from .controlnet_utils import CONTROLNET_DICT

# Lower-case image file extensions, each with its leading dot.
# NOTE(review): presumably the extensions recognised as video-frame files;
# no use of this constant is visible here, so confirm against callers.
FRAME_EXT = [".jpg", ".png"]


def init_model(device="cuda", sd_version="1.5", model_key=None, control_type="none", model_path=None,
               weight_dtype="fp16"):
    """Load a Stable Diffusion pipeline and move it to ``device``.

    Args:
        device: Target handed to ``pipe.to``.
        sd_version: ``'2.1'``, ``'2.0'``, ``'1.5'``, ``'1.4'`` or ``'depth'``.
            Only consulted when ``model_key`` is ``None``.
        model_key: Explicit checkpoint identifier; when given, ``sd_version``
            is ignored and the depth pipeline is never selected.
        control_type: Key into ``CONTROLNET_DICT``. ``"none"`` and ``"pnp"``
            load no ControlNet.
        model_path: Forwarded as ``cache_dir`` to every ``from_pretrained``.
        weight_dtype: ``"fp16"`` selects ``torch.float16``; any other value
            selects ``torch.float32``.

    Returns:
        Tuple ``(pipe, model_key)`` with the pipeline already on ``device``.

    Raises:
        ValueError: If ``model_key`` is ``None`` and ``sd_version`` is not
            one of the supported values.
    """
    checkpoints = (
        ('2.1', "stabilityai/stable-diffusion-2-1-base"),
        ('2.0', "stabilityai/stable-diffusion-2-base"),
        ('1.5', "runwayml/stable-diffusion-v1-5"),
        ('1.4', "CompVis/stable-diffusion-v1-4"),
        ('depth', "stabilityai/stable-diffusion-2-depth"),
    )

    use_depth = False
    if model_key is not None:
        print(f'[INFO] loading custome model from: {model_key}')
    else:
        model_key = next(
            (checkpoint for version, checkpoint in checkpoints if sd_version == version),
            None,
        )
        if model_key is None:
            raise ValueError(
                f'Stable-diffusion version {sd_version} not supported.')
        # Only the built-in depth checkpoint switches to the depth2img pipeline.
        use_depth = sd_version == 'depth'
        print(f'[INFO] loading stable diffusion from: {model_key}')

    weight_dtype = torch.float16 if weight_dtype == "fp16" else torch.float32
    load_kwargs = {"torch_dtype": weight_dtype, "cache_dir": model_path}

    if control_type not in ("none", "pnp"):
        # A ControlNet request takes precedence over the depth pipeline.
        controlnet_key = CONTROLNET_DICT[control_type]
        print(f'[INFO] loading controlnet from: {controlnet_key}')
        controlnet = ControlNetModel.from_pretrained(controlnet_key, **load_kwargs)
        print('[INFO] loaded controlnet!')
        pipe = StableDiffusionControlNetPipeline.from_pretrained(
            model_key, controlnet=controlnet, **load_kwargs
        )
    else:
        pipeline_cls = StableDiffusionDepth2ImgPipeline if use_depth else StableDiffusionPipeline
        pipe = pipeline_cls.from_pretrained(model_key, **load_kwargs)

    return pipe.to(device), model_key


def seed_everything(seed):
    """Seed the PyTorch (CPU and CUDA), Python ``random`` and NumPy generators with ``seed``."""
    seeders = (torch.manual_seed, torch.cuda.manual_seed, random.seed, np.random.seed)
    for apply_seed in seeders:
        apply_seed(seed)


def load_image(image_paths):
    """Load images from ``image_paths`` and stack them into one tensor.

    Each file is opened, converted to RGB, passed through ``T.Resize(512)``
    and ``T.ToTensor()``; the per-image tensors are then combined with
    ``torch.stack``, so they must all end up with the same shape.
    """
    def _as_tensor(image_path):
        rgb = Image.open(image_path).convert('RGB')
        return T.ToTensor()(T.Resize(512)(rgb))

    return torch.stack([_as_tensor(image_path) for image_path in image_paths])


def save_img(img, path):
    """Write the tensor ``img`` to ``path`` as an image file.

    Args:
        img: Image tensor; singleton dimensions are squeezed away before the
            tensor is converted with ``T.ToPILImage``.
        path: Destination file (``str`` or path-like). An existing file at
            this location is removed first and missing parent directories
            are created.
    """
    if Path(path).exists():
        os.remove(path)

    # A bare filename has an empty dirname, and os.makedirs("") raises
    # FileNotFoundError, so only create the parent when there is one.
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)

    T.ToPILImage()(img.squeeze()).save(path)


# From pix2video: code/file_utils.py

def load_depth(model, depth_path: Path, input_image, dtype=torch.float32):
    """Return the depth map cached at ``depth_path``, computing it on a miss.

    Args:
        model: Pipeline object passed to ``prepare_depth_map``; its ``device``
            attribute selects where the estimate runs. Unused on a cache hit.
        depth_path: Location of the cached tensor file.
        input_image: Image tensor the depth is estimated from; squeezed and
            converted to a PIL image. Unused on a cache hit.
        dtype: Forwarded to ``prepare_depth_map``.

    Returns:
        The depth map. On a cache hit it is loaded onto the CPU; on a miss it
        is whatever ``prepare_depth_map`` returned.

    Side effects (cache miss only): writes the tensor to ``depth_path`` and a
    grayscale preview next to it with the suffix replaced by ``.png``.
    """
    if depth_path.exists():
        # NOTE(security): weights_only=False runs the full unpickler, which can
        # execute arbitrary code. Only point this at cache files written here.
        return torch.load(depth_path, map_location="cpu", weights_only=False)

    pil_image = T.ToPILImage()(input_image.squeeze())
    depth_map = prepare_depth_map(model, pil_image, dtype=dtype, device=model.device)

    # Make sure the parent directory exists before writing the cache.
    depth_path.parent.mkdir(parents=True, exist_ok=True)
    torch.save(depth_map, depth_path)

    # Map the [-1, 1] depth range to [0, 255] for a human-viewable preview.
    depth_image = (((depth_map + 1.0) / 2.0) * 255).to(torch.uint8)
    # with_suffix only touches the final extension; a plain string replace of
    # ".pt" would also rewrite directory names that happen to contain it.
    preview_path = depth_path.with_suffix(".png")
    T.ToPILImage()(depth_image.squeeze()).convert("L").save(preview_path)

    return depth_map


@torch.no_grad()
def prepare_depth_map(model, image, depth_map=None, batch_size=1, do_classifier_free_guidance=False,
                      dtype=torch.float32, device="cuda"):
    """Produce a latent-resolution depth map normalised to ``[-1, 1]``.

    Args:
        model: Object exposing ``vae_scale_factor`` and, when ``depth_map`` is
            ``None``, ``feature_extractor`` and ``depth_estimator``.
        image: A single PIL image, or a batch: a sequence of PIL images, of
            ``H x W x C`` NumPy arrays, or of ``(..., H, W)`` tensors. Only
            the first element is inspected to find the spatial size.
        depth_map: Optional precomputed depth of shape ``(B, H, W)``; entries
            equal to ``-1`` are treated as background. Never modified.
        batch_size: If larger than the depth batch, the map is tiled
            ``batch_size // B`` times along the batch dimension.
        do_classifier_free_guidance: When true, the result is concatenated
            with itself along the batch dimension.
        dtype: Output dtype, and autocast dtype when estimating on CUDA.
        device: ``str`` or ``torch.device`` to run on.

    Returns:
        Tensor of shape ``(B', 1, H // vae_scale_factor, W // vae_scale_factor)``.

    NOTE: a constant depth map makes ``depth_max == depth_min`` and the
    normalisation then yields NaN.
    """
    # Accept plain strings too: the autocast check below needs ``device.type``.
    device = torch.device(device)

    # Built-in containers and arrays are batches; only a lone PIL image needs
    # wrapping. The cheap checks run first so PIL is consulted only for
    # objects that are none of those.
    is_single_pil = (
        not isinstance(image, (list, tuple, np.ndarray, torch.Tensor))
        and isinstance(image, Image.Image)
    )
    images = [image] if is_single_pil else list(image)

    first = images[0]
    if isinstance(first, np.ndarray):
        # NumPy images are laid out H x W x C, so height comes first.
        height, width = first.shape[:-1]
    elif isinstance(first, torch.Tensor):
        height, width = first.shape[-2:]
    elif isinstance(first, Image.Image):
        width, height = first.size
    else:
        height, width = first.shape[-2:]

    if depth_map is None:
        pixel_values = model.feature_extractor(
            images=images, return_tensors="pt").pixel_values
        pixel_values = pixel_values.to(device=device)
        # The DPT-Hybrid model uses batch-norm layers which are not compatible with fp16.
        # So we use `torch.autocast` here for half precision inference.
        autocast_ctx = (
            torch.autocast("cuda", dtype=dtype)
            if device.type == "cuda"
            else contextlib.nullcontext()
        )
        with autocast_ctx:
            depth_map = model.depth_estimator(pixel_values).predicted_depth
    else:
        depth_map = depth_map.to(device=device, dtype=dtype)

    # Push background pixels (marked -1) well below the nearest real depth.
    # torch.where builds a new tensor, leaving the caller's map untouched.
    background = depth_map == -1
    if background.any():
        floor = depth_map[~background].min() - 10
        depth_map = torch.where(background, floor, depth_map)

    depth_map = torch.nn.functional.interpolate(
        depth_map.unsqueeze(1),
        size=(height // model.vae_scale_factor,
              width // model.vae_scale_factor),
        mode="bicubic",
        align_corners=False,
    )

    # Per-sample min/max normalisation to [-1, 1].
    depth_min = torch.amin(depth_map, dim=[1, 2, 3], keepdim=True)
    depth_max = torch.amax(depth_map, dim=[1, 2, 3], keepdim=True)
    depth_map = 2.0 * (depth_map - depth_min) / (depth_max - depth_min) - 1.0
    depth_map = depth_map.to(dtype)

    # Duplicate the map for each generation per prompt (mps-friendly repeat).
    if depth_map.shape[0] < batch_size:
        repeat_by = batch_size // depth_map.shape[0]
        depth_map = depth_map.repeat(repeat_by, 1, 1, 1)

    if do_classifier_free_guidance:
        depth_map = torch.cat([depth_map] * 2)
    return depth_map


def combine_images(directory):
    """Tile every ``*.png`` in ``directory`` side by side into ``combined_image.png``.

    Images are placed left to right, oldest modification time first, on a
    white canvas. Each image's file stem is drawn as a label in a 20-pixel
    strip directly below that image.

    Args:
        directory: Folder to scan, as ``str`` or ``Path``. A previous
            ``combined_image.png`` in it is deleted before scanning.

    Raises:
        ValueError: If the directory contains no ``.png`` files.
    """
    directory = Path(directory)
    output_path = directory / "combined_image.png"

    # Remove any earlier result so it is not tiled into the new one.
    if output_path.exists():
        os.remove(output_path)

    # Oldest first (ascending modification time).
    image_paths = sorted(directory.glob("*.png"), key=lambda p: p.stat().st_mtime)
    if not image_paths:
        raise ValueError(f"no .png images found in {directory}")

    # Best effort: with font=None PIL falls back to its own default when drawing.
    try:
        font = ImageFont.load_default()
    except Exception:
        font = None

    label_height = 20  # Vertical space reserved for the labels

    # ExitStack closes every opened image once the canvas has been composed.
    with contextlib.ExitStack() as stack:
        images = [stack.enter_context(Image.open(image_path)) for image_path in image_paths]

        total_width = sum(image.width for image in images)
        max_height = max(image.height for image in images) + label_height
        combined_image = Image.new("RGB", (total_width, max_height), (255, 255, 255))

        draw = ImageDraw.Draw(combined_image)
        x_offset = 0
        for img, image_path in zip(images, image_paths):
            combined_image.paste(img, (x_offset, 0))

            label = image_path.stem
            # Measured from the origin, so right/bottom double as width/height.
            text_width, text_height = draw.textbbox((0, 0), label, font=font)[2:4]

            # Centre the label horizontally under its image, inside the strip.
            text_x = x_offset + (img.width - text_width) // 2
            text_y = img.height + (label_height - text_height) // 2
            draw.text((text_x, text_y), label, fill="black", font=font)

            x_offset += img.width

    combined_image.save(output_path)


def mask_decode(encoded_mask, image_shape=[512, 512]):
    """Decode a run-length encoded mask into a 2-D array of zeros and ones.

    Args:
        encoded_mask: Flat sequence of ``start, run_length`` pairs, where
            ``start`` is an offset into the row-major flattened image. Runs
            that extend past the end of the image are clipped.
        image_shape: ``[rows, cols]`` of the decoded mask. The default list
            is only read, never mutated.

    Returns:
        Float array of shape ``(rows, cols)``. The outermost rows and columns
        are always set to 1.

    Raises:
        IndexError: If ``encoded_mask`` has an odd number of entries.
    """
    rows, cols = image_shape[0], image_shape[1]
    length = rows * cols
    mask_array = np.zeros((length,))

    for i in range(0, len(encoded_mask), 2):
        start, run = encoded_mask[i], encoded_mask[i + 1]
        # One slice assignment per run instead of a per-pixel Python loop;
        # min() clips runs that would overshoot the flattened image.
        mask_array[start:min(start + run, length)] = 1

    mask_array = mask_array.reshape(rows, cols)
    # to avoid annotation errors in boundary
    mask_array[0, :] = 1
    mask_array[-1, :] = 1
    mask_array[:, 0] = 1
    mask_array[:, -1] = 1

    return mask_array

def load_512(image_path, left=0, right=0, top=0, bottom=0):
    """Load an image, trim its borders, centre-crop to a square, resize to 512x512.

    Args:
        image_path: A file path (``str`` or path-like) or an already loaded
            ``H x W x C`` array. Images read from disk keep only their first
            three channels.
        left, right, top, bottom: Number of pixels to trim from each side.
            Each value is clamped so that at least one pixel remains.

    Returns:
        NumPy array with 512 rows and 512 columns.
    """
    if isinstance(image_path, (str, os.PathLike)):
        image = np.array(Image.open(image_path))[:, :, :3]
    else:
        image = image_path

    h, w, _ = image.shape
    left = min(left, w - 1)
    right = min(right, w - left - 1)
    # Vertical trimming is bounded by the height alone; the horizontal offset
    # ``left`` has no bearing on how many rows may be removed.
    top = min(top, h - 1)
    bottom = min(bottom, h - top - 1)
    image = image[top:h - bottom, left:w - right]

    # Centre-crop the longer side so the result is square.
    h, w, _ = image.shape
    if h < w:
        offset = (w - h) // 2
        image = image[:, offset:offset + h]
    elif w < h:
        offset = (h - w) // 2
        image = image[offset:offset + w]

    return np.array(Image.fromarray(image).resize((512, 512)))
